# -*- coding: utf-8 -*-
"""Untitled27.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1B9SZoZrU-cBPcNbOTgAVuCVQKlpso2-R
"""

# -*- coding: utf-8 -*-
"""
Main script to run Matrix Product Approximation Experiments.

This script defines the experiment parameters, uses the matrix generators and
algorithms from matrix_product_approximations_exp2, runs the experiments, and
plots the results.
"""

import os
import sys
import time
import traceback

import numpy as np
import scipy
import scipy.sparse

# Pull the matrix generators, experiment runners, bound computation and plotting
# helpers from the companion library module. The script cannot run without them,
# so a failed import aborts immediately.
try:
    from matrix_product_approximations_exp2 import (
        generate_matrices_uniform,
        generate_matrices_gaussian_cancellation,
        generate_matrices_row_orthogonal,
        generate_matrices_repeated_cols,
        generate_matrices_nonlinear,
        calculate_rho_g,
        run_algorithm_experiments,
        compute_all_bounds_orchestrator,
        plot_multi_panel_results,
        IMPROVED_STYLES,  # Passed explicitly to the plot function
    )
except ImportError as import_err:
    # Report the module that is actually imported above, together with the
    # underlying reason (missing file vs. missing name inside the module).
    print(f"ERROR: Could not import from matrix_product_approximations_exp2.py: {import_err}")
    print("Ensure matrix_product_approximations_exp2.py is in the same directory or Python path.")
    # sys.exit instead of the interactive-only exit() helper, and a non-zero
    # status so callers (shells, CI) can detect the failure.
    sys.exit(1)

# ==============================================================================
# HELPERS
# ==============================================================================
def build_k_values(common_dim, ratio_start, ratio_end, num_steps):
    """Return the sorted, de-duplicated integer k values to sweep over.

    The sweep runs from ``ratio_start * common_dim`` to
    ``ratio_end * common_dim`` (both truncated to integers), clamped so that
    every value lies in ``[1, common_dim]``.

    Args:
        common_dim: Size of the dimension the k values are drawn from.
        ratio_start: Lower end of the sweep as a fraction of ``common_dim``.
        ratio_end: Upper end of the sweep as a fraction of ``common_dim``.
        num_steps: Requested number of sweep points; values <= 0 are treated
            as 1. Fewer points are returned when truncation creates duplicates.

    Returns:
        A 1-D integer ``numpy.ndarray`` with at least one element. When the
        range is degenerate (or ``common_dim`` is 0) a single fallback value
        of at least 1 is returned.
    """
    k_start = max(1, int(ratio_start * common_dim))
    # An inverted range collapses onto its start point.
    k_end = max(k_start, min(common_dim, int(ratio_end * common_dim)))
    num_steps = max(1, num_steps)

    if k_start == k_end:
        k_values = np.array([k_start], dtype=int)
    else:
        # Integer truncation of the evenly spaced grid can repeat values.
        k_values = np.unique(np.linspace(k_start, k_end, num=num_steps, dtype=int))

    k_values = k_values[(k_values > 0) & (k_values <= common_dim)]
    if k_values.size == 0:
        k_values = np.array([max(1, min(common_dim, k_start))], dtype=int)
    return k_values


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================
if __name__ == "__main__":
    print("--- CONFIGURATION ---")
    # User updated dimensions: A(n_exp x m_exp), B(p_exp x m_exp)
    # m_exp is the common dimension from which k columns are selected/sketched
    n_exp = 20  # Number of rows in A
    m_exp = 100  # Common dimension (number of columns in A, number of columns in B)
    p_exp = 30  # Number of rows in B

    # k sweep expressed as fractions of the common dimension
    k_ratio_start = 0.05
    k_ratio_end = 0.45
    num_k_steps = 10
    n_trials_exp = 10  # Passed to run_algorithm_experiments for every k
    main_seed = 2024  # Base seed; each matrix family offsets it by its index

    plot_directory = f"plots_N{n_exp}_Mcommon{m_exp}_P{p_exp}_MultiPanel_NoLegend_v5_Refactored"

    k_values_experiment = build_k_values(m_exp, k_ratio_start, k_ratio_end, num_k_steps)

    print(f"Matrix dimensions: A({n_exp} x {m_exp}), B({p_exp} x {m_exp})")
    print(f"k values (from {m_exp} common columns): {k_values_experiment}")
    print(f"Plot Directory: {plot_directory}")
    # exist_ok avoids the check-then-create race, and makedirs raises right here
    # if the path is occupied by a regular file instead of failing at plot time.
    plot_directory_preexisting = os.path.isdir(plot_directory)
    os.makedirs(plot_directory, exist_ok=True)
    if not plot_directory_preexisting:
        print(f"Created directory: {plot_directory}")
    print("-" * 30)

    # Registry of test-matrix families, keyed by the label used in logs and plots.
    # Each value is a callable taking an optional ``seed``; it forwards the fixed
    # experiment sizes to a library generator in the order
    # (rows of A, rows of B, common columns), so A is (n_exp x m_exp) and
    # B is (p_exp x m_exp).
    def _with_experiment_dims(generator, **extra_kwargs):
        """Wrap *generator* as ``f(seed=None)`` with the experiment sizes filled in."""
        return lambda seed=None: generator(n_exp, p_exp, m_exp, **extra_kwargs, seed=seed)

    matrix_generators_dict = {
        "Uniform_-1_1": _with_experiment_dims(generate_matrices_uniform),
        # Uses the cancellation generator with its default settings.
        "Gaussian_Standard": _with_experiment_dims(generate_matrices_gaussian_cancellation),
        "Row_Orthogonal": _with_experiment_dims(generate_matrices_row_orthogonal),
        "Gaussian_Cancel_10pct_Noise05": _with_experiment_dims(
            generate_matrices_gaussian_cancellation,
            cancel_fraction=0.1,
            noise_level=0.05,
        ),
        "RepeatedCols_10pct_n1pct": _with_experiment_dims(
            generate_matrices_repeated_cols,
            repeat_frac=0.1,
            noise_ratio=0.01,
        ),
        "NonLinear_Tanh_GaussBase": _with_experiment_dims(
            generate_matrices_nonlinear,
            base_type='gaussian',
            func=np.tanh,
        ),
    }

    overall_run_start_time = time.time()
    # Maps generator label -> dict of plottable series plus the scalar 'Rho_G'.
    all_experimental_results = {}
    n_generators = len(matrix_generators_dict)

    for i_gen, (matrix_gen_key, generator_callable) in enumerate(matrix_generators_dict.items()):
        print(f"\n===== Running Experiment {i_gen+1}/{n_generators}: {matrix_gen_key} =====")
        # Distinct but reproducible seed for every matrix family.
        current_exp_seed = main_seed + i_gen
        try:
            # Expected shapes: A_curr (n_exp x m_exp), B_curr (p_exp x m_exp).
            A_curr, B_curr = generator_callable(seed=current_exp_seed)
            if A_curr.shape[1] == 0:  # m_exp is 0
                print(f"Skipping {matrix_gen_key} due to empty common dim (m_exp = {m_exp}).")
                continue

            # calculate_rho_g is given dense arrays, so sparse inputs are
            # densified first. scipy.sparse is imported explicitly at the top
            # of the file rather than relying on `import scipy` exposing it.
            A_dense_rho = A_curr.toarray() if scipy.sparse.issparse(A_curr) else A_curr
            B_dense_rho = B_curr.toarray() if scipy.sparse.issparse(B_curr) else B_curr
            rho_g_val = calculate_rho_g(A_dense_rho, B_dense_rho)
            print(f"Rho_G = {rho_g_val:.4f}")

            algo_errs, experiment_data = run_algorithm_experiments(
                A_curr, B_curr, k_values_experiment, n_trials_exp, current_exp_seed
            )
            current_frob_ABt_sq = experiment_data.get('frob_ABt_sq', np.nan)

            # Start from the algorithm measurements; bounds are merged in only
            # when they can actually be computed.
            plot_record = dict(algo_errs)
            if np.isnan(current_frob_ABt_sq):
                print(f"Skipping bounds for {matrix_gen_key} due to exact product failure or zero norm.")
            else:
                try:
                    # m_exp is passed as the common dimension.
                    all_bounds_results = compute_all_bounds_orchestrator(
                        A_curr, B_curr, k_values_experiment, current_frob_ABt_sq, m_exp
                    )
                except Exception as e_bounds:
                    # A failure here must not discard the algorithm results
                    # already measured; the record is stored without bounds,
                    # like the NaN case above.
                    print(f"ERROR computing bounds for '{matrix_gen_key}': {e_bounds}")
                    traceback.print_exc()
                else:
                    plot_record.update(all_bounds_results)
                    # Drop the redundant k key contributed by the bounds.
                    plot_record.pop('k_values_bounds', None)

            # Set last so it cannot be overwritten by the merged dictionaries.
            plot_record['Rho_G'] = rho_g_val
            all_experimental_results[matrix_gen_key] = plot_record

        except Exception as e_exp:
            # Per-experiment boundary: report, then continue with the next family.
            print(f"MAJOR ERROR during experiment for '{matrix_gen_key}': {e_exp}")
            traceback.print_exc()
            # Store minimal data to avoid crashing plot grouping
            all_experimental_results[matrix_gen_key] = {'k': k_values_experiment, 'Rho_G': np.nan}


    # Two multi-panel figures; each lists the generator labels it displays.
    plot_groups = [
        {
            "name": "Group1_Uniform_Gaussian_RowOrthogonal",
            "keys": ["Uniform_-1_1", "Gaussian_Standard", "Row_Orthogonal"],
        },
        {
            "name": "Group2_Cancel_Repeated_Nonlinear",
            "keys": ["Gaussian_Cancel_10pct_Noise05", "RepeatedCols_10pct_n1pct", "NonLinear_Tanh_GaussBase"],
        },
    ]

    for group_info in plot_groups:
        group_name = group_info["name"]
        results_for_group = []
        names_for_group = []

        for key in group_info["keys"]:
            # A record is plottable only if it exists and carries k data under
            # either of the two accepted keys.
            if key not in all_experimental_results:
                is_plottable = False
            else:
                record = all_experimental_results[key]
                is_plottable = not (
                    record.get('k') is None and record.get('k_values_bounds') is None
                )

            if not is_plottable:
                print(f"Warning: Data for '{key}' not found or incomplete. Excluding from plot group '{group_name}'.")
                continue

            results_for_group.append(all_experimental_results[key])
            names_for_group.append(key)

        if not results_for_group:
            print(f"No data to plot for group '{group_name}'. Plotting skipped.")
            continue

        print(f"\n--- Plotting: {group_name} (No Legends) ---")
        plot_multi_panel_results(
            results_list=results_for_group,
            matrix_type_names=names_for_group,
            figure_title_prefix=group_name,
            common_dim_for_k_ratio=m_exp,  # k ratios are taken relative to m_exp
            styles_dict=IMPROVED_STYLES,  # styles imported from the library
            plot_dir_path=plot_directory,
            log_scale_y_axis=True,
        )

    # Wall-clock summary for the whole run (experiments plus plotting).
    overall_run_end_time = time.time()
    elapsed_seconds = overall_run_end_time - overall_run_start_time
    print(f"\n--- All Experiments Finished in {elapsed_seconds:.2f} seconds ---")

